[Pytorch] 详解 torch.cat() |
您所在的位置:网站首页 › r cat函数 › [Pytorch] 详解 torch.cat() |
1. 定义
官方手册中描述为: torch.cat(inputs, dimension=0) → Tensor在给定维度上对输入的张量序列seq 进行连接操作。 torch.cat()可以看做 torch.split() 和 torch.chunk()的反操作。 cat() 函数可以通过下面例子更好的理解。 参数: inputs (sequence of Tensors) – 可以是任意相同Tensor 类型的python 序列dimension (int, optional) – 沿着此维连接张量序列。 2. 例子 >>> x = torch.randn(2, 3) >>> x 0.5983 -0.0341 2.4918 1.5981 -0.5265 -0.8735 [torch.FloatTensor of size 2x3] >>> torch.cat((x, x, x), 0) 0.5983 -0.0341 2.4918 1.5981 -0.5265 -0.8735 0.5983 -0.0341 2.4918 1.5981 -0.5265 -0.8735 0.5983 -0.0341 2.4918 1.5981 -0.5265 -0.8735 [torch.FloatTensor of size 6x3] >>> torch.cat((x, x, x), 1) 0.5983 -0.0341 2.4918 0.5983 -0.0341 2.4918 0.5983 -0.0341 2.4918 1.5981 -0.5265 -0.8735 1.5981 -0.5265 -0.8735 1.5981 -0.5265 -0.8735 [torch.FloatTensor of size 2x9]torch.cat((x, x, x), 1)中的 0 or 1 就是指示的维度。 除此之外,可以指示为-1。 我将举几个例子 如图,a是2x3 b是2x5的一个张量 拼接后: 一句话总结:上下拼接要列数相同,左右拼接要行数相同。 另,用torch.cat拼接list里的tensor: 先整个list: 可以清楚的看到已经拼接好了,即参数可以直接传入一个seq |
今日新闻 |
点击排行 |
|
推荐新闻 |
图片新闻 |
|
专题文章 |
CopyRight 2018-2019 实验室设备网 版权所有 win10的实时保护怎么永久关闭 |